

\section{自定义数据读取
}\label{ux81eaux5b9aux4e49ux6570ux636eux8bfbux53d6}

基本要求:

\begin{itemize}
\tightlist
\item
  熟悉 C++ 编程。
\item
  确保\href{tensorflow-zh/SOURCE/get_started/os_setup.md\#source}{下载
  TensorFlow 源文件}, 并可编译使用。
\end{itemize}

我们将支持文件格式的任务分成两部分：

\begin{itemize}
\tightlist
\item
  文件格式: 我们使用 \emph{Reader} Op来从文件中读取一个 \emph{record}
  (可以使任意字符串)。
\item
  记录格式:
  我们使用解码器或者解析运算将一个字符串记录转换为TensorFlow可以使用的张量。
\end{itemize}

例如， 读取一个
\href{https://en.wikipedia.org/wiki/Comma-separated_values}{CSV
文件}，我们使用
\href{tensorflow-zh/SOURCE/api_docs/python/io_ops.md\#TextLineReader}{一个文本读写器}，
然后是\href{tensorflow-zh/SOURCE/api_docs/python/io_ops.md\#decode_csv}{从一行文本中解析CSV数据的运算}。

\subsection{主要内容}\label{ux4e3bux8981ux5185ux5bb9}

\subsubsection{\texorpdfstring{\protect\hyperlink{AUTOGENERATED-custom-data-readers}{自定义数据读取}}{自定义数据读取}}\label{ux81eaux5b9aux4e49ux6570ux636eux8bfbux53d6-1}

\begin{itemize}
\tightlist
\item
  \protect\hyperlink{AUTOGENERATED-writing-a-reader-for-a-file-format}{编写一个文件格式读写器}
\item
  \protect\hyperlink{AUTOGENERATED-writing-an-op-for-a-record-format}{编写一个记录格式Op}
\end{itemize}

\subsection{编写一个文件格式读写器
}\label{ux7f16ux5199ux4e00ux4e2aux6587ux4ef6ux683cux5f0fux8bfbux5199ux5668}

Reader
是专门用来读取文件中的记录的。TensorFlow中内建了一些读写器Op的实例：

\begin{itemize}
\tightlist
\item
  \href{tensorflow-zh/SOURCE/api_docs/python/io_ops.md\#TFRecordReader}{tf.TFRecordReader}
  (\href{https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/kernels/tf_record_reader_op.cc}{代码位于kernels/tf\_record\_reader\_op.cc})
\item
  \href{tensorflow-zh/SOURCE/api_docs/python/io_ops.md\#FixedLengthRecordReader}{tf.FixedLengthRecordReader}
  (\href{https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/kernels/fixed_length_record_reader_op.cc}{代码位于
  kernels/fixed\_length\_record\_reader\_op.cc})
\item
  \href{tensorflow-zh/SOURCE/api_docs/python/io_ops.md\#TextLineReader}{tf.TextLineReader}
  (\href{https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/kernels/text_line_reader_op.cc}{代码位于
  kernels/text\_line\_reader\_op.cc})
\end{itemize}

你可以看到这些读写器的界面是一样的，唯一的差异是在它们的构造函数中。最重要的方法是
Read。
它需要一个行列参数，通过这个行列参数，可以在需要的时候随时读取文件名
(例如： 当 Read Op首次运行，或者前一个 Read`
从一个文件中读取最后一条记录时)。它将会生成两个标量张量：
一个字符串和一个字符串关键值。

新创建一个名为 SomeReader 的读写器，需要以下步骤：

\begin{enumerate}
\def\labelenumi{\arabic{enumi}.}
\tightlist
\item
  在 C++ 中, 定义一个
  \href{https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/kernels/reader_base.h}{tensorflow::ReaderBase}的子类，命名为
  ``SomeReader''.
\item
  在 C++ 中，注册一个新的读写器Op和Kernel，命名为 ``SomeReader''。
\item
  在 Python 中, 定义一个
  \href{https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/python/ops/io_ops.py}{tf.ReaderBase}
  的子类，命名为 ``SomeReader''。
\end{enumerate}

你可以把所有的 C++ 代码放在
\texttt{tensorflow/core/user\_ops/some\_reader\_op.cc}文件中.
读取文件的代码将被嵌入到C++ 的 ReaderBase 类的迭代中。 这个 ReaderBase
类 是在
\href{https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/kernels/reader_base.h}{tensorflow/core/kernels/reader\_base.h}
中定义的。 你需要执行以下的方法：

\begin{itemize}
\tightlist
\item
  OnWorkStartedLocked：打开下一个文件
\item
  ReadLocked：读取一个记录或报告 EOF/error
\item
  OnWorkFinishedLocked：关闭当前文件
\item
  ResetLocked：清空记录，例如：一个错误记录
\end{itemize}

以上这些方法的名字后面都带有 ``Locked''， 表示 ReaderBase
在调用任何一个方法之前确保获得互斥锁，这样就不用担心线程安全（虽然只保护了该类中的元素而不是全局的）。

对于 OnWorkStartedLocked, 需要打开的文件名是 \texttt{current\_work()}
函数的返回值。此时的 ReadLocked 的数字签名如下:

\begin{verbatim}
Status ReadLocked(string* key, string* value, bool* produced, bool* at_end)
\end{verbatim}

如果 ReadLocked 从文件中成功读取了一条记录，它将更新为：

\begin{itemize}
\tightlist
\item
  *key： 记录的标志位，通过该标志位可以重新定位到该记录。 可以包含从
  current\_work() 返回值获得的文件名，并追加一个记录号或其他信息。
\item
  *value： 包含记录的内容。
\item
  *produced： 设置为 true。
\end{itemize}

当你在文件（EOF）末尾，设置 *at\_end 为 true ，在任何情况下，都将返回
Status::OK()。 当出现错误的时候，只需要使用
\href{https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/lib/core/errors.h}{tensorflow/core/lib/core/errors.h}
中的一个辅助功能就可以简单地返回，不需要做任何参数修改。

接下来你讲创建一个实际的读写器Op。
如果你已经熟悉了\href{tensorflow-zh/SOURCE/how_tos/adding_an_op/index.md}{添加新的Op}
那会很有帮助。 主要步骤如下：

\begin{itemize}
\tightlist
\item
  注册Op。
\item
  定义并注册 OpKernel。
\end{itemize}

要注册Op，你需要用到一个调用指令定义在
\href{https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/op.h}{tensorflow/core/framework/op.h}中的REGISTER\_OP。

读写器 Op 没有输入，只有 Ref(string) 类型的单输出。它们调用
SetIsStateful()，并有一个 container 字符串和 shared\_name 属性.
你可以在一个 Doc 中定义配置或包含文档的额外属性。 例如：详见
\href{https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/ops/io_ops.cc}{tensorflow/core/ops/io\_ops.cc}等：

\begin{verbatim}
 #include "tensorflow/core/framework/op.h"
REGISTER_OP("TextLineReader")
    .Output("reader_handle: Ref(string)")
    .Attr("skip_header_lines: int = 0")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .Doc(R"doc(
A Reader that outputs the lines of a file delimited by '\n'.
)doc");
\end{verbatim}

要定义一个 OpKernel，
读写器可以使用定义在\href{https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/reader_op_kernel.h}{tensorflow/core/framework/reader\_op\_kernel.h}中的
ReaderOpKernel 的递减快捷方式，并运行一个叫 SetReaderFactory
的构造函数。 定义所需要的类之后，你需要通过
REGISTER\_KERNEL\_BUILDER(\ldots{}) 注册这个类。

一个没有属性的例子：

\begin{verbatim}
 #include "tensorflow/core/framework/reader_op_kernel.h"
class TFRecordReaderOp : public ReaderOpKernel {
 public:
  explicit TFRecordReaderOp(OpKernelConstruction* context)
      : ReaderOpKernel(context) {
    Env* env = context->env();
    SetReaderFactory([this, env]() { return new TFRecordReader(name(), env); });
  }
};
REGISTER_KERNEL_BUILDER(Name("TFRecordReader").Device(DEVICE_CPU),
                        TFRecordReaderOp);
\end{verbatim}

一个带有属性的例子：

\begin{verbatim}
 #include "tensorflow/core/framework/reader_op_kernel.h"
class TextLineReaderOp : public ReaderOpKernel {
 public:
  explicit TextLineReaderOp(OpKernelConstruction* context)
      : ReaderOpKernel(context) {
    int skip_header_lines = -1;
    OP_REQUIRES_OK(context,
                   context->GetAttr("skip_header_lines", &skip_header_lines));
    OP_REQUIRES(context, skip_header_lines >= 0,
                errors::InvalidArgument("skip_header_lines must be >= 0 not ",
                                        skip_header_lines));
    Env* env = context->env();
    SetReaderFactory([this, skip_header_lines, env]() {
      return new TextLineReader(name(), skip_header_lines, env);
    });
  }
};
REGISTER_KERNEL_BUILDER(Name("TextLineReader").Device(DEVICE_CPU),
                        TextLineReaderOp);
\end{verbatim}

最后一步是添加 Python 包装器，你需要将 tensorflow.python.ops.io\_ops
导入到
\href{https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/python/user_ops/user_ops.py}{tensorflow/python/user\_ops/user\_ops.py}，并添加一个
\href{https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/python/ops/io_ops.py}{io\_ops.ReaderBase}的衍生函数。

\begin{verbatim}
from tensorflow.python.framework import ops
from tensorflow.python.ops import common_shapes
from tensorflow.python.ops import io_ops
class SomeReader(io_ops.ReaderBase):
    def __init__(self, name=None):
        rr = gen_user_ops.some_reader(name=name)
        super(SomeReader, self).__init__(rr)
ops.NoGradient("SomeReader")
ops.RegisterShape("SomeReader")(common_shapes.scalar_shape)
\end{verbatim}

你可以在
\href{https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/python/ops/io_ops.py}{tensorflow/python/ops/io\_ops.py}中查看一些范例。

\subsection{编写一个记录格式Op
}\label{ux7f16ux5199ux4e00ux4e2aux8bb0ux5f55ux683cux5f0fop}

一般来说，这是一个普通的Op， 需要一个标量字符串记录作为输入， 因此遵循
\href{tensorflow-zh/SOURCE/how_tos/adding_an_op/index.md}{添加Op的说明}。
你可以选择一个标量字符串作为输入，
并包含在错误消息中报告不正确的格式化数据。

用于解码记录的运算实例：

\begin{itemize}
\tightlist
\item
  \href{tensorflow-zh/SOURCE/api_docs/python/io_ops.md\#parse_single_example}{tf.parse\_single\_example}
  (and
  \href{tensorflow-zh/SOURCE/api_docs/python/io_ops.md\#parse_example}{tf.parse\_example})
\item
  \href{tensorflow-zh/SOURCE/api_docs/python/io_ops.md\#decode_csv}{tf.decode\_csv}
\item
  \href{tensorflow-zh/SOURCE/api_docs/python/io_ops.md\#decode_raw}{tf.decode\_raw}
\end{itemize}

请注意，使用多个Op 来解码某个特定的记录格式也是有效的。
例如，你有一张以字符串格式保存在
\href{https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/example/example.proto}{tf.train.Example
协议缓冲区}的图像文件。 根据该图像的格式， 你可能从
\href{tensorflow-zh/SOURCE/api_docs/python/io_ops.md\#parse_single_example}{tf.parse\_single\_example}
的Op 读取响应输出并调用
\href{tensorflow-zh/SOURCE/api_docs/python/image.md\#decode_jpeg}{tf.decode\_jpeg}，
\href{tensorflow-zh/SOURCE/api_docs/python/image.md\#decode_png}{tf.decode\_png}，
或者
\href{tensorflow-zh/SOURCE/api_docs/python/io_ops.md\#decode_raw}{tf.decode\_raw}。通过读取
tf.decode\_raw
的响应输出并使用\href{tensorflow-zh/SOURCE/api_docs/python/array_ops.md\#slice}{tf.slice}
和
\href{tensorflow-zh/SOURCE/api_docs/python/array_ops.md\#reshape}{tf.reshape}
来提取数据是通用的方法。 \textgreater{}
原文：\href{http://tensorflow.org/how_tos/new_data_formats/index.html\#custom-data-readers}{Custom
Data Readers} 翻译：{[}@derekshang{]}(https://github.com/derekshang)
校对：\href{https://github.com/jikexueyuanwiki}{Wiki}